#!/usr/bin/env python3
"""
HTTP/SSE Integration test for the refactored Graphiti MCP Server.
Tests server functionality when running in SSE (Server-Sent Events) mode over HTTP.
Note: This test requires the server to be running with --transport sse.
"""

import asyncio
import json
import time
from typing import Any

import httpx


class MCPIntegrationTest:
    """Integration test client for a Graphiti MCP Server reachable over HTTP.

    Intended to be used as an async context manager so that the underlying
    HTTP client is closed when the run finishes::

        async with MCPIntegrationTest() as test:
            results = await test.run_full_test_suite()

    Every episode written by a run is tagged with ``test_group_id``, which is
    derived from the Unix time at construction.
    """

    def __init__(self, base_url: str = 'http://localhost:8000'):
        """Create the HTTP client and the per-run group id.

        Args:
            base_url: Root URL of the running MCP server.
        """
        self.base_url = base_url
        self.client = httpx.AsyncClient(timeout=30.0)
        self.test_group_id = f'test_group_{int(time.time())}'

    async def __aenter__(self):
        """Return this client; no setup beyond ``__init__`` is performed."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP client; exceptions from the body are not suppressed."""
        await self.client.aclose()

    async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call an MCP tool by POSTing a JSON-RPC ``tools/call`` message.

        Args:
            tool_name: Name of the tool to invoke.
            arguments: Arguments forwarded to the tool.

        Returns:
            The ``result`` member of the decoded response when present,
            otherwise the whole decoded response. Any failure (non-200
            status, transport error, undecodable body) is reported as
            ``{'error': <description>}`` instead of being raised.
        """
        # MCP protocol message structure
        message = {
            'jsonrpc': '2.0',
            'id': int(time.time() * 1000),
            'method': 'tools/call',
            'params': {'name': tool_name, 'arguments': arguments},
        }

        # Broad catch is deliberate: this is the boundary where every failure
        # is converted into an error dict that the test methods report.
        try:
            response = await self.client.post(
                f'{self.base_url}/message',
                json=message,
                headers={'Content-Type': 'application/json'},
            )

            if response.status_code != 200:
                return {'error': f'HTTP {response.status_code}: {response.text}'}

            result = response.json()
            return result.get('result', result)

        except Exception as e:
            return {'error': str(e)}

    async def test_server_status(self) -> bool:
        """Fetch the status resource and report whether it says ``'ok'``.

        Returns:
            True only when the request returns HTTP 200 and the decoded body
            has ``status == 'ok'``; False on any other outcome or exception.
        """
        print('🔍 Testing server status...')

        try:
            response = await self.client.get(f'{self.base_url}/resources/http://graphiti/status')
            if response.status_code == 200:
                status = response.json()
                print(f'   ✅ Server status: {status.get("status", "unknown")}')
                return status.get('status') == 'ok'
            else:
                print(f'   ❌ Status check failed: HTTP {response.status_code}')
                return False
        except Exception as e:
            print(f'   ❌ Status check failed: {e}')
            return False

    async def test_add_memory(self) -> dict[str, str]:
        """Add one text, one JSON and one message episode via ``add_memory``.

        Returns:
            A dict with one ``<source>: 'success'`` entry for each episode
            whose call did not return an error; failed sources are absent.
        """
        print('📝 Testing add_memory functionality...')

        json_data = {
            'company': {'name': 'TechCorp', 'founded': 2010},
            'products': [
                {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
                {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
            ],
            'employees': 150,
        }

        # (source, label in the intro line, label in the outcome line,
        #  episode name, episode body, source description)
        episodes = [
            (
                'text',
                'text',
                'Text',
                'Test Company News',
                'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
                'news article',
            ),
            (
                'json',
                'JSON',
                'JSON',
                'Company Profile',
                json.dumps(json_data),
                'CRM data',
            ),
            (
                'message',
                'message',
                'Message',
                'Customer Support Chat',
                "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
                'support chat log',
            ),
        ]

        episode_results = {}
        for source, intro_label, outcome_label, name, body, description in episodes:
            print(f'   Testing {intro_label} episode...')
            result = await self.call_mcp_tool(
                'add_memory',
                {
                    'name': name,
                    'episode_body': body,
                    'source': source,
                    'source_description': description,
                    'group_id': self.test_group_id,
                },
            )

            if 'error' in result:
                print(f'   ❌ {outcome_label} episode failed: {result["error"]}')
            else:
                print(f'   ✅ {outcome_label} episode queued: {result.get("message", "Success")}')
                episode_results[source] = 'success'

        return episode_results

    async def wait_for_processing(self, max_wait: int = 30, poll_interval: float = 1.0) -> None:
        """Poll ``get_episodes`` until the test group has at least one episode.

        Args:
            max_wait: Maximum number of polls. With the default interval this
                is also the approximate number of seconds spent sleeping.
            poll_interval: Seconds to sleep before each poll.

        Returns:
            None. Returns as soon as a poll yields a non-empty list; otherwise
            prints a warning after ``max_wait`` polls. Never raises on a
            failed poll, because ``call_mcp_tool`` reports failures as dicts.
        """
        budget = max_wait * poll_interval
        print(f'⏳ Waiting up to {budget:g} seconds for episode processing...')

        for attempt in range(1, max_wait + 1):
            await asyncio.sleep(poll_interval)

            # Check if we have any episodes
            result = await self.call_mcp_tool(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
            )

            # Only a non-empty list counts as "processed". Error dicts and
            # empty lists fall through to the next poll. (The previous guard
            # skipped every non-dict result, so this branch never ran.)
            if isinstance(result, list) and result:
                waited = attempt * poll_interval
                print(f'   ✅ Found {len(result)} processed episodes after {waited:g} seconds')
                return

        print(f'   ⚠️  Still waiting after {budget:g} seconds...')

    async def test_search_functions(self) -> dict[str, bool]:
        """Exercise ``search_memory_nodes`` and ``search_memory_facts``.

        Returns:
            ``{'nodes': bool, 'facts': bool}``, where each flag is True when
            the corresponding call returned without an ``'error'`` key. The
            number of hits is printed but does not affect the flag.
        """
        print('🔍 Testing search functions...')

        results = {}

        # Test search_memory_nodes
        print('   Testing search_memory_nodes...')
        result = await self.call_mcp_tool(
            'search_memory_nodes',
            {
                'query': 'Acme Corp product launch',
                'group_ids': [self.test_group_id],
                'max_nodes': 5,
            },
        )

        if 'error' in result:
            print(f'   ❌ Node search failed: {result["error"]}')
            results['nodes'] = False
        else:
            nodes = result.get('nodes', [])
            print(f'   ✅ Node search returned {len(nodes)} nodes')
            results['nodes'] = True

        # Test search_memory_facts
        print('   Testing search_memory_facts...')
        result = await self.call_mcp_tool(
            'search_memory_facts',
            {
                'query': 'company products software',
                'group_ids': [self.test_group_id],
                'max_facts': 5,
            },
        )

        if 'error' in result:
            print(f'   ❌ Fact search failed: {result["error"]}')
            results['facts'] = False
        else:
            facts = result.get('facts', [])
            print(f'   ✅ Fact search returned {len(facts)} facts')
            results['facts'] = True

        return results

    async def test_episode_retrieval(self) -> bool:
        """Fetch up to 10 recent episodes for the test group and print a few.

        Returns:
            True when the call yields a non-empty list; False for an error
            dict, an empty list, or any other result shape.
        """
        print('📚 Testing episode retrieval...')

        result = await self.call_mcp_tool(
            'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
        )

        if isinstance(result, dict) and 'error' in result:
            print(f'   ❌ Episode retrieval failed: {result["error"]}')
            return False

        if not isinstance(result, list):
            print(f'   ❌ Unexpected result format: {type(result)}')
            return False

        print(f'   ✅ Retrieved {len(result)} episodes')

        # Print episode details
        for i, episode in enumerate(result[:3]):  # Show first 3
            name = episode.get('name', 'Unknown')
            source = episode.get('source', 'unknown')
            print(f'     Episode {i + 1}: {name} (source: {source})')

        return len(result) > 0

    async def test_edge_cases(self) -> dict[str, bool]:
        """Check that an unknown group id and an empty query do not error.

        Returns:
            ``{'invalid_group': bool, 'empty_query': bool}``, where each flag
            is True when the corresponding call returned without an
            ``'error'`` key.
        """
        print('🧪 Testing edge cases...')

        results = {}

        # Test with invalid group_id
        print('   Testing invalid group_id...')
        result = await self.call_mcp_tool(
            'search_memory_nodes',
            {'query': 'nonexistent data', 'group_ids': ['nonexistent_group'], 'max_nodes': 5},
        )

        # Should not error, just return empty results
        if 'error' not in result:
            nodes = result.get('nodes', [])
            print(f'   ✅ Invalid group_id handled gracefully (returned {len(nodes)} nodes)')
            results['invalid_group'] = True
        else:
            print(f'   ❌ Invalid group_id caused error: {result["error"]}')
            results['invalid_group'] = False

        # Test empty query
        print('   Testing empty query...')
        result = await self.call_mcp_tool(
            'search_memory_nodes', {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5}
        )

        if 'error' not in result:
            print('   ✅ Empty query handled gracefully')
            results['empty_query'] = True
        else:
            print(f'   ❌ Empty query caused error: {result["error"]}')
            results['empty_query'] = False

        return results

    async def run_full_test_suite(self) -> dict[str, Any]:
        """Run every test in order and print a summary.

        The run aborts early (returning the initial all-false results) when
        the status check fails.

        Returns:
            A dict with keys ``server_status``, ``add_memory``, ``search``,
            ``episodes``, ``edge_cases`` and ``overall_success``. Overall
            success requires a passing status check, at least one episode
            added, successful episode retrieval, and at least one passing
            search or edge-case check.
        """
        print('🚀 Starting Graphiti MCP Server Integration Test')
        print(f'   Test group ID: {self.test_group_id}')
        print('=' * 60)

        results = {
            'server_status': False,
            'add_memory': {},
            'search': {},
            'episodes': False,
            'edge_cases': {},
            'overall_success': False,
        }

        # Test 1: Server Status
        results['server_status'] = await self.test_server_status()
        if not results['server_status']:
            print('❌ Server not responding, aborting tests')
            return results

        print()

        # Test 2: Add Memory
        results['add_memory'] = await self.test_add_memory()
        print()

        # Test 3: Wait for processing
        await self.wait_for_processing()
        print()

        # Test 4: Search Functions
        results['search'] = await self.test_search_functions()
        print()

        # Test 5: Episode Retrieval
        results['episodes'] = await self.test_episode_retrieval()
        print()

        # Test 6: Edge Cases
        results['edge_cases'] = await self.test_edge_cases()
        print()

        # Calculate overall success
        memory_success = len(results['add_memory']) > 0
        search_success = any(results['search'].values())
        edge_case_success = any(results['edge_cases'].values())

        results['overall_success'] = (
            results['server_status']
            and memory_success
            and results['episodes']
            and (search_success or edge_case_success)  # At least some functionality working
        )

        # Print summary
        print('=' * 60)
        print('📊 TEST SUMMARY')
        print(f'   Server Status: {"✅" if results["server_status"] else "❌"}')
        print(
            f'   Memory Operations: {"✅" if memory_success else "❌"} ({len(results["add_memory"])} types)'
        )
        print(f'   Search Functions: {"✅" if search_success else "❌"}')
        print(f'   Episode Retrieval: {"✅" if results["episodes"] else "❌"}')
        print(f'   Edge Cases: {"✅" if edge_case_success else "❌"}')
        print()
        print(f'🎯 OVERALL: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')

        if results['overall_success']:
            print('   The refactored MCP server is working correctly!')
        else:
            print('   Some issues detected. Check individual test results above.')

        return results


async def main():
    """Run the integration suite and terminate the process with its status.

    Raises:
        SystemExit: Always. The code is 0 when the suite reports
            ``overall_success`` and 1 otherwise.
    """
    async with MCPIntegrationTest() as test:
        results = await test.run_full_test_suite()

    # The bare ``exit`` helper is injected by the ``site`` module for
    # interactive use and is missing under ``python -S`` or in embedded
    # interpreters. Raising SystemExit works everywhere and needs no import.
    # It is raised after the ``async with`` so the HTTP client is closed first.
    raise SystemExit(0 if results['overall_success'] else 1)


# Script entry point: run the suite only when executed directly, not on import.
# main() ends the process itself with an exit code reflecting the overall result.
if __name__ == '__main__':
    asyncio.run(main())
